This vignette is designed to demonstrate the functionality and limitations of the ALEPlotPlus package.
First, let’s demonstrate the package on a simple, simulated dataset.
# set the number of observations
n = 1000
b0 = 0
b1 = 10
b2 = 1
b3 = 10
b12 = 0
b13 = 10
b23 = 0
#simulate values of x1, x2, and x3 from a normal dist
x1 = rnorm(n, mean = 0, sd = 1)
x2 = rnorm(n, mean = 0, sd = 1)
x3 = rnorm(n, mean = 0, sd = 1)
# calculate noise
eps = rnorm(n, mean = 0,sd = .001)
# compute the response from the underlying model (the b23 term multiplies x2 and x3)
y <- b0 + b1*x1 + b2*x2 + b3*x3 + b12*x1*x2 + b13*x1*x3 + b23*x2*x3 + eps
df = data.frame(x1 = x1,x2 = x2,x3 = x3, y = y)
X = df[,-4]
mdl = randomForest::randomForest(y ~.,data = df)
# commented-out version of the code, shown so the steps are visible without the plot output the original ALEPlot functions produce
# pred.fun <- function(X.model, newdata) {
# prd_out = predict(X.model, newdata)
# return(prd_out)
# }
# var_imps = calc_ALE_varimp(X = X,K = 40, MODEL = mdl,pred.fun = pred.fun)
# MainEffectALES = ALEPlotPlus::mkALEplots(X = X,K = 40, X.MODEL = mdl,pred.fun = pred.fun)
var_imps
## var varimp
## 1 x1 6.8782308
## 2 x2 0.5351686
## 3 x3 7.0270039
MainEffectALES
## [[1]]
##
## [[2]]
##
## [[3]]
The results correctly estimate that x1 and x3 have roughly 10x the effect of x2 on the response. The ALE plots generally describe the linear relationship between each predictor and the outcome correctly, although they are less accurate at extreme values and for variables that matter less to the model. The erratic behavior at these extremes is where this tool is useful: by showing the density of points used to estimate the effect at each value, it helps you judge how much confidence to place in the ALE plot at those points.
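To make the role of point density concrete, here is a minimal base-R sketch of a first-order ALE curve, written from the standard definition (average prediction differences across quantile bins, accumulated and then centered). It is not the ALEPlotPlus implementation: `ale_1d`, `predict_fn`, `X_demo`, and `f_demo` are hypothetical names introduced for illustration, and a known function stands in for the fitted model.

```r
# First-order ALE for one predictor, written from the textbook definition.
# Hypothetical helper for illustration only, not part of ALEPlotPlus.
ale_1d <- function(X, predict_fn, var, K = 40) {
  x <- X[[var]]
  # Bin edges at empirical quantiles, so each bin holds about n / K points
  z <- unique(quantile(x, probs = seq(0, 1, length.out = K + 1), names = FALSE))
  bin <- cut(x, breaks = z, include.lowest = TRUE, labels = FALSE)
  # Move every point to the lower and upper edge of its own bin
  X_lo <- X
  X_hi <- X
  X_lo[[var]] <- z[bin]
  X_hi[[var]] <- z[bin + 1]
  # Average prediction change within each bin, then accumulate across bins
  local_effect <- tapply(predict_fn(X_hi) - predict_fn(X_lo), bin, mean)
  counts <- as.numeric(table(bin))
  ale <- c(0, cumsum(local_effect))
  # Center so that the count-weighted mean effect is zero
  mid <- (ale[-1] + ale[-length(ale)]) / 2
  ale <- ale - sum(mid * counts) / sum(counts)
  list(x = z, ale = as.numeric(ale), counts = counts, width = diff(z))
}

set.seed(42)
X_demo <- data.frame(x1 = rnorm(1000), x2 = rnorm(1000))
# A known additive function stands in for a fitted model
f_demo <- function(newdata) 10 * newdata$x1 + newdata$x2
res <- ale_1d(X_demo, f_demo, "x1", K = 40)
slopes <- diff(res$ale) / res$width  # slope of the ALE curve in each bin
```

For this additive function every entry of `slopes` equals the true coefficient of 10. The entries of `res$width` are much larger in the tails than in the center: the same number of points is spread over a wider range there, which is why estimates from a real fitted model are less stable at the extremes.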
Now let’s examine the interactions.
# commented-out code, shown so the steps are visible without the output of the original ALEPlot functions cluttering the markdown
# ALEints = ALEPlotPlus::find_ALE_interacts(X,model = mdl,pred_fun = pred.fun,K = 40)
# ALEPlotPlus::create_2D_ALEs(X,model = mdl,pred_fun=pred.fun,K = 40,savedir = "~/exampleplots")
clrs = QckRBrwrPllt(name = "OrRd",n = 100)
ints4heatmap = ALEPlotPlus::interacts2_heatmap(ALEints)
ints4heatmap
## x1 x2 x3
## x1 NA 0.6951276 4.8407476
## x2 0.6951276 NA 0.8910706
## x3 4.8407476 0.8910706 NA
pheatmap::pheatmap(ints4heatmap,color = clrs)
knitr::include_graphics("~/exampleplots/x1x2.jpeg")
knitr::include_graphics("~/exampleplots/x1x3.jpeg")
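ALEPlotPlus correctly identifies the interaction: the x1:x3 score (4.84) is far larger than the x1:x2 (0.70) and x2:x3 (0.89) scores, matching the simulation, in which b13 is the only nonzero interaction coefficient.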
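The heatmap input printed above is a symmetric matrix with NA on the diagonal. As a rough sketch of how such a matrix can be assembled from one score per variable pair, the base-R snippet below uses the three scores from this example. The long-format `pair_scores` data frame and its column names are assumptions made for illustration; the actual return value of `find_ALE_interacts` and the internals of `interacts2_heatmap` may differ.

```r
# Hypothetical long-format table: one row per unordered variable pair
pair_scores <- data.frame(
  var1 = c("x1", "x1", "x2"),
  var2 = c("x2", "x3", "x3"),
  strength = c(0.6951276, 4.8407476, 0.8910706),
  stringsAsFactors = FALSE
)

vars <- sort(unique(c(pair_scores$var1, pair_scores$var2)))
# Start from an all-NA matrix so the diagonal stays NA
m <- matrix(NA_real_, nrow = length(vars), ncol = length(vars),
            dimnames = list(vars, vars))
# Fill both triangles so the matrix is symmetric
m[cbind(pair_scores$var1, pair_scores$var2)] <- pair_scores$strength
m[cbind(pair_scores$var2, pair_scores$var1)] <- pair_scores$strength
m
```

Keeping the diagonal as NA leaves those cells blank in the heatmap, so the color scale is driven only by the pairwise scores.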
The Pima Indians Diabetes dataset originates from the National Institute of Diabetes and Digestive and Kidney Diseases. The objective of the dataset is to diagnostically predict whether or not a patient has diabetes, based on certain diagnostic measurements included in the dataset.
The dataset consists of several medical predictor (independent) variables and one target (dependent) variable, Outcome. Predictor variables include the number of pregnancies the patient has had, their BMI, insulin level, age, and more.
Here’s a detailed breakdown of the variables in the dataset:
Pregnancies: Number of times pregnant.
Glucose: Plasma glucose concentration at 2 hours in an oral glucose tolerance test.
BloodPressure: Diastolic blood pressure (mm Hg).
SkinThickness: Triceps skin fold thickness (mm).
Insulin: 2-hour serum insulin (mu U/ml).
BMI: Body mass index (weight in kg/(height in m)^2).
DiabetesPedigreeFunction: Diabetes pedigree function (a function that scores likelihood of diabetes based on family history).
Age: Age in years.
Outcome: Class variable (0 if non-diabetic, 1 if diabetic).
# commented-out code, shown so the steps are visible without the output of the original ALEPlot functions cluttering the markdown
# set.seed(1500)
#
#
# replace_0_with_mean <- function(x) {
# mean_x <- mean(x, na.rm = TRUE)
# ifelse(x == 0, mean_x, x)
# }
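# # Perform data processing
# library(mlbench) # provides the PimaIndiansDiabetes data
# library(dplyr) # provides %>%, filter, and select
# data(PimaIndiansDiabetes)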
# pimi = PimaIndiansDiabetes
# pimi$diabetes = as.factor(pimi$diabetes == "pos")
# #pimi = pimi %>% mutate_at(c("mass","glucose","insulin"),replace_0_with_mean)
# pimi = pimi %>% filter(mass > 0) %>% filter(glucose > 0) %>% filter(insulin > 0) %>% filter(triceps >0) %>% filter(pressure > 0)
# X = pimi %>% dplyr::select(-diabetes)
#
# pimi.model = randomForest::randomForest(diabetes ~ .,pimi,ntree = 1000)
#
# pred_fun <- function(X.model, newdata) {
# prd_out = predict(X.model, newdata, type="prob")
# return(prd_out[,2])
# }
#
# var_imps = calc_ALE_varimp(X = X,K = 40, MODEL = pimi.model,pred.fun = pred_fun)
#
# pimi_ints_heatmap = find_ALE_interacts(X = X,model = pimi.model,pred_fun=pred_fun,K = 40) %>%
# interacts2_heatmap()
#
# ALEPlotPlus::create_2D_ALEs(X = X, model = pimi.model, pred_fun = pred_fun, K = 40,
# savedir = "~/pimitest")
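The filtering step in the code above drops every row in which a physiological measurement is recorded as 0, since a zero glucose, insulin, or BMI value in this dataset stands in for a missing measurement. The base-R sketch below shows the same idea on a small made-up table; `toy`, `measured`, and the values in it are illustrative assumptions, not rows from the real data.

```r
# Made-up rows for illustration; 0 marks a missing measurement
toy <- data.frame(
  glucose = c(148, 0, 183, 89, 137),
  insulin = c(94, 168, 0, 94, 168),
  mass    = c(33.6, 26.6, 23.3, 28.1, 43.1)
)

measured <- c("glucose", "insulin", "mass")
# Keep only rows with no zero in any of the measured columns
complete <- rowSums(toy[, measured] == 0) == 0
toy_clean <- toy[complete, ]
toy_clean
```

Dropping rows is the simpler choice; the commented-out `replace_0_with_mean` helper above is the alternative of imputing the column mean instead, which keeps more rows at the cost of flattening the affected variables.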
var_imps
## var varimp
## 1 pregnant 0.02062627
## 2 glucose 0.15278211
## 3 pressure 0.01570669
## 4 triceps 0.02937670
## 5 insulin 0.07192683
## 6 mass 0.05321226
## 7 pedigree 0.04599836
## 8 age 0.08027492
pheatmap::pheatmap(pimi_ints_heatmap,cluster_rows=F,cluster_cols=F,
color = QckRBrwrPllt("OrRd",100),display_numbers=T,
number_format="%.3f")
knitr::include_graphics("~/pimitest/glucoseage.jpeg")
knitr::include_graphics("~/pimitest/insulinage.jpeg")
The interaction heatmap suggests that two of the strongest interactions are glucose with age and insulin with age, so the 2D ALE plots for both pairs are shown above. They show that patients over 30 with high glucose and insulin are at greater risk of diabetes than either predictor would suggest alone, indicating a synergy. This is consistent with the known biology: older people in poor metabolic health are at high risk of Type 2 diabetes.
Overall, this demonstrates that the approach can summarize the 2D interactions quantified by ALE plots in a heatmap, which then points to the pairs worth following up on.
Now, let’s break the approach using simulated data to illustrate its limitations and pitfalls. Let’s fit a model with two highly correlated variables.
# set the number of observations
n = 10000
b0 = 0
b1 = 10
b2 = 1
b3 = 10
b12 = 10
b13 = 10
b23 = 0
#simulate values of x1, x2, and x3 from a normal dist
x1 = rnorm(n, mean = 0, sd = 1)
x2 = rnorm(n, mean = 0, sd = 1)
x3 = x1 + rnorm(n, mean = 3, sd = 1)
# calculate noise
eps = rnorm(n, mean = 0,sd = .001)
# compute the response from the underlying model (the b23 term multiplies x2 and x3)
y <- b0 + b1*x1 + b2*x2 + b3*x3 + b12*x1*x2 + b13*x1*x3 + b23*x2*x3 + eps
df = data.frame(x1 = x1,x2 = x2,x3 = x3, y = y)
X = df[,-4]
mdl = randomForest::randomForest(y ~.,data = df)
# commented-out code, shown so the steps are visible without the output produced by the ALEPlot functions
# pred.fun <- function(X.model, newdata) {
# prd_out = predict(X.model, newdata)
# return(prd_out)
# }
# var_imps = calc_ALE_varimp(X = X,K = 40, MODEL = mdl,pred.fun = pred.fun)
# MainEffectALES = ALEPlotPlus::mkALEplots(X = X,K = 40, X.MODEL = mdl,pred.fun = pred.fun)
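# ALEints = ALEPlotPlus::find_ALE_interacts(X,model = mdl,pred_fun = pred.fun,K = 40) # recompute ALEints for this dataset
# create_2D_ALEs(X,model = mdl,pred_fun=pred.fun,K = 40,savedir = "~/exampleplots_correlated")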
clrs = QckRBrwrPllt(name = "OrRd",n = 100)
ints4heatmap = ALEPlotPlus::interacts2_heatmap(ALEints)
ints4heatmap
## x1 x2 x3
## x1 NA 0.6951276 4.8407476
## x2 0.6951276 NA 0.8910706
## x3 4.8407476 0.8910706 NA
pheatmap::pheatmap(ints4heatmap,color = clrs,cluster_rows=F,cluster_cols=F)
MainEffectALES
## [[1]]
##
## [[2]]
##
## [[3]]
knitr::include_graphics("~/exampleplots_correlated/x1x2.jpeg")
knitr::include_graphics("~/exampleplots_correlated/x1x3.jpeg")
Note how, despite b12 and b13 having the same value, the heatmap scores the x1:x3 interaction as much stronger than the x1:x2 interaction. This is because ALE plots have difficulty separating the interaction between x1 and x3 from the effect of their correlation. Notice also how the ALE plot misestimates the shape of the main effect of x3. The linear model below does not suffer from these issues.
summary(lm(y ~ x1*x2 + x2*x3 + x1*x3,data = df))
##
## Call:
## lm(formula = y ~ x1 * x2 + x2 * x3 + x1 * x3, data = df)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.0037471 -0.0006805 0.0000064 0.0006912 0.0042736
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 1.901e-05 3.263e-05 5.820e-01 0.5602
## x1 1.000e+01 2.300e-05 4.348e+05 <2e-16 ***
## x2 9.999e-01 3.161e-05 3.163e+04 <2e-16 ***
## x3 1.000e+01 1.017e-05 9.835e+05 <2e-16 ***
## x1:x2 1.000e+01 1.410e-05 7.094e+05 <2e-16 ***
## x2:x3 2.285e-05 1.004e-05 2.276e+00 0.0229 *
## x1:x3 1.000e+01 6.062e-06 1.650e+06 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.001006 on 9993 degrees of freedom
## Multiple R-squared: 1, Adjusted R-squared: 1
## F-statistic: 4.738e+12 on 6 and 9993 DF, p-value: < 2.2e-16
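One way to see why the correlated pair is hard to separate is to ask how much of the product x1*x3 can already be explained by x1 alone. The sketch below is an illustration under stated assumptions rather than a description of what ALEPlotPlus does internally: it regenerates x1 and x3 with the same recipe as the simulation above, and `r2_from_x1_alone` is a hypothetical helper.

```r
set.seed(7)
n <- 10000
x1 <- rnorm(n)
x3_corr <- x1 + rnorm(n, mean = 3, sd = 1)  # correlated with x1, as in this example
x3_indep <- rnorm(n)                         # independent of x1, as in the first example

# Share of the variance of x1 * x3 explained by a curve in x1 alone
r2_from_x1_alone <- function(x1, x3) {
  product <- x1 * x3
  summary(lm(product ~ x1 + I(x1^2)))$r.squared
}

cor_x1_x3 <- cor(x1, x3_corr)
r2_corr <- r2_from_x1_alone(x1, x3_corr)
r2_indep <- r2_from_x1_alone(x1, x3_indep)
c(cor = cor_x1_x3, r2_corr = r2_corr, r2_indep = r2_indep)
```

With x3 = x1 + noise, the theoretical correlation is 1/sqrt(2) (about 0.71) and the product expands to x1^2 + 3*x1 + x1*(noise - 3), so about 11/12 of its variance looks like a main effect of x1. With independent predictors that share is essentially zero, which is the setting in which a pairwise interaction score is easiest to interpret.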
Next, let’s investigate what happens when the main effects differ vastly in magnitude.
# set the number of observations
n = 10000
b0 = 0
b1 = 10
b2 = 100000
b3 = 10
b12 = 10
b13 = 10
b23 = 0
#simulate values of x1, x2, and x3 from a normal dist
x1 = rnorm(n, mean = 0, sd = 1)
x2 = rnorm(n, mean = 0, sd = 1)
x3 = rnorm(n, mean = 0, sd = 1)
# calculate noise
eps = rnorm(n, mean = 0,sd = .001)
# compute the response from the underlying model (the b23 term multiplies x2 and x3)
y <- b0 + b1*x1 + b2*x2 + b3*x3 + b12*x1*x2 + b13*x1*x3 + b23*x2*x3 + eps
df = data.frame(x1 = x1,x2 = x2,x3 = x3, y = y)
X = df[,-4]
mdl = randomForest::randomForest(y ~.,data = df)
# commented-out code, following the same steps as the earlier examples
# var_imps = calc_ALE_varimp(X = X,K = 40, MODEL = mdl,pred.fun = pred.fun)
# ALEints = ALEPlotPlus::find_ALE_interacts(X,model = mdl,pred_fun = pred.fun,K = 40)
# ALEPlotPlus::create_2D_ALEs(X,model = mdl,pred_fun=pred.fun,K = 40,savedir = "~/exampleplots_maineffectdiff")
clrs = QckRBrwrPllt(name = "OrRd",n = 100)
ints4heatmap = ALEPlotPlus::interacts2_heatmap(ALEints)
ints4heatmap
## x1 x2 x3
## x1 NA 1843.182 1491.891
## x2 1843.182 NA 1420.451
## x3 1491.891 1420.451 NA
pheatmap::pheatmap(ints4heatmap,color = clrs,cluster_rows=F,cluster_cols=F)
knitr::include_graphics("~/exampleplots_maineffectdiff/x2x3.jpeg")
knitr::include_graphics("~/exampleplots_maineffectdiff/x1x3.jpeg")
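As shown, the effect of x2 is so dominant that it interferes with the estimates for the other two variables: all three interaction scores are of similar size even though b23 is 0 while b12 and b13 are both 10, and the variable importances below differ widely for x1 and x3 despite their equal coefficients. Linear models do not suffer from this limitation.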
var_imps
## var varimp
## 1 x1 757.3964
## 2 x2 70923.1962
## 3 x3 265.0766
summary(lm(y ~ x1*x2 + x2*x3 + x1*x3,data = df))
##
## Call:
## lm(formula = y ~ x1 * x2 + x2 * x3 + x1 * x3, data = df)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.0043279 -0.0006636 -0.0000015 0.0006699 0.0034032
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 7.537e-06 1.000e-05 7.530e-01 0.4513
## x1 1.000e+01 1.013e-05 9.873e+05 <2e-16 ***
## x2 1.000e+05 1.001e-05 9.994e+09 <2e-16 ***
## x3 1.000e+01 1.011e-05 9.891e+05 <2e-16 ***
## x1:x2 1.000e+01 1.014e-05 9.864e+05 <2e-16 ***
## x2:x3 1.900e-05 1.016e-05 1.870e+00 0.0615 .
## x1:x3 1.000e+01 1.018e-05 9.823e+05 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.001 on 9993 degrees of freedom
## Multiple R-squared: 1, Adjusted R-squared: 1
## F-statistic: 1.666e+19 on 6 and 9993 DF, p-value: < 2.2e-16
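A plausible reading of this failure, offered here as an assumption rather than a diagnosis of the fitted forest, is that a tree ensemble approximates the x2 main effect with piecewise-constant steps, and the error of that approximation alone can be far larger than the interaction signal. The base-R sketch below mimics this with a step function over 200 quantile bins of x2; the bin count and all variable names are arbitrary choices made for illustration.

```r
set.seed(1)
n <- 10000
x1 <- rnorm(n)
x2 <- rnorm(n)

main_term <- 100000 * x2     # dominant main effect (b2 * x2)
inter_term <- 10 * x1 * x2   # true interaction (b12 * x1 * x2)

# Piecewise-constant approximation of the main effect: the mean within each of 200 bins
breaks <- quantile(x2, probs = seq(0, 1, length.out = 201), names = FALSE)
bins <- cut(x2, breaks = breaks, include.lowest = TRUE)
step_fit <- ave(main_term, bins)

approx_err_sd <- sd(main_term - step_fit)  # what the step function fails to capture
interaction_sd <- sd(inter_term)           # size of the signal we hope to detect
c(approx_err_sd = approx_err_sd, interaction_sd = interaction_sd)
```

When the approximation error of the dominant term is much larger than the interaction term, any score computed from the model's predictions mostly reflects that error, which is consistent with the three near-equal interaction scores in the heatmap above.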
Overall, we show that this approach can be used to quantify main-effect and interaction strength in machine learning models. However, it has limitations when variables are highly correlated and when interactions are dwarfed by strong main effects.